# Plotting code for confidence intervals
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import wandb
from matplotlib.ticker import ScalarFormatter, LogFormatter

def load_results(job_id, results_directory='/directory/with/results'):
    """Load the aggregated results JSON file for one job.

    Reads ``<results_directory>/aggregated_results_<job_id>.json`` and returns
    the object parsed by ``json.load``.

    Args:
        job_id: Job identifier interpolated into the file name (str or int).
        results_directory: Directory containing the ``aggregated_results_*.json`` files.

    Returns:
        The decoded JSON content.

    Raises:
        FileNotFoundError: if the file does not exist.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    json_file = os.path.join(results_directory, f'aggregated_results_{job_id}.json')
    # JSON is UTF-8; be explicit so decoding does not depend on the platform locale.
    with open(json_file, 'r', encoding='utf-8') as f:
        return json.load(f)

def plot_results(job_ids, colors, labels, metric, title, xlabel, ylabel, log_every=5, start_idx=0, end_idx=None,
                 figsize=(10, 6), title_fontsize=20, label_fontsize=16, legend_fontsize=14, tick_fontsize=12,
                 save_directory='/directory/save/confidence_intervals',
                 wandb_project='Confidence_Interval', wandb_entity='some-agents', run_name='custom_run',
                 y_min=None, y_max=None, line_thickness=2, markers=None, max_markers=20,
                 x_logarithm=False, y_logarithm=False, manual_ticks=False,
                 results_directory=None, wandb_dir='/directory/for/wandb_log/', pdf_filename='plot.pdf',
                 x_ticks=(30, 40, 60, 100, 200), y_ticks=(2, 3, 4)):
    """Plot mean curves with shaded confidence intervals for several jobs.

    For each job the results are loaded with ``load_results``. The function reads
    ``results[metric]['mean']`` and the lower/upper bounds from
    ``results[metric]['interval'][0]`` and ``[1]``. Each series is sliced to
    ``[start_idx:end_idx]`` and drawn at x = index * ``log_every``. The figure is
    saved as a PDF in ``save_directory`` and shown. The PDF is then uploaded to
    Weights & Biases as an artifact named ``'plot_pdf'`` of type ``'report'``.

    This is best-effort: any exception is caught, reported on stdout with its
    traceback, and not re-raised.

    Args:
        job_ids: Job identifiers; one curve per job.
        colors, labels: Per-job line colour and legend label (indexed like ``job_ids``).
        metric: Key of the metric in each results JSON.
        title, xlabel, ylabel: Figure title and axis labels.
        log_every: Multiplier from stored index to x-axis value.
        start_idx, end_idx: Python-slice bounds applied to each job's series
            independently. ``end_idx=None`` means "to the end of that job's data".
            Out-of-range bounds are clamped like ordinary slicing.
        figsize, *_fontsize, line_thickness: Matplotlib styling.
        save_directory: Directory for the PDF; created if missing.
        wandb_project, wandb_entity, run_name: Passed to ``wandb.init``.
        y_min, y_max: Optional y-axis limits; either may be given alone.
        markers: Optional per-job marker styles.
        max_markers: A marker is drawn every ``len(x) // max_markers`` points
            (at least every point).
        x_logarithm, y_logarithm: Use a log scale on the corresponding axis.
        manual_ticks: If true, use scalar tick labels at ``x_ticks`` / ``y_ticks``.
        results_directory: Directory with the results JSON files. ``None`` uses
            ``load_results``' default.
        wandb_dir: Value assigned to the ``WANDB_DIR`` environment variable
            before ``wandb.init``.
        pdf_filename: File name of the PDF written into ``save_directory``.
        x_ticks, y_ticks: Tick positions used when ``manual_ticks`` is true.
    """
    fig = None
    try:
        os.makedirs(save_directory, exist_ok=True)
        fig = plt.figure(figsize=figsize)
        window = slice(start_idx, end_idx)

        for i, job_id in enumerate(job_ids):
            if results_directory is None:
                results = load_results(job_id)
            else:
                results = load_results(job_id, results_directory)
            series = results[metric]

            # Slice the x positions with the same window as the data, so x and y
            # always have equal length. This holds even when jobs differ in length
            # or end_idx exceeds a job's length. The old code resolved end_idx=None
            # from the first job only and reused that value for every later job.
            mean = np.asarray(series['mean'])
            x = np.arange(len(mean))[window] * log_every
            mean = mean[window]
            lower_interval = np.asarray(series['interval'][0])[window]
            upper_interval = np.asarray(series['interval'][1])[window]

            marker = markers[i] if markers is not None else None
            mark_step = max(1, len(x) // max_markers)
            plt.plot(x, mean, color=colors[i], label=labels[i], linewidth=line_thickness,
                     marker=marker, markevery=mark_step)
            plt.fill_between(x, lower_interval, upper_interval, color=colors[i], alpha=0.3)

        # Axis scales are figure-wide; set them once rather than on every job.
        if x_logarithm:
            plt.xscale('log')
        if y_logarithm:
            plt.yscale('log')

        plt.title(title, fontsize=title_fontsize)
        plt.xlabel(xlabel, fontsize=label_fontsize)
        plt.ylabel(ylabel, fontsize=label_fontsize)
        plt.legend(fontsize=legend_fontsize)
        plt.grid(True)
        plt.tick_params(axis='both', which='both', labelsize=tick_fontsize)

        # A None bound leaves that side of the current limits unchanged.
        if y_min is not None or y_max is not None:
            plt.ylim(bottom=y_min, top=y_max)

        if manual_ticks:
            ax = plt.gca()
            # Plain numbers instead of e.g. 10^x labels on log axes.
            ax.xaxis.set_major_formatter(plt.ScalarFormatter())
            ax.yaxis.set_major_formatter(plt.ScalarFormatter())
            ax.xaxis.set_ticks(list(x_ticks))
            ax.yaxis.set_ticks(list(y_ticks))

        plt.tight_layout()
        pdf_file = os.path.join(save_directory, pdf_filename)
        plt.savefig(pdf_file, format='pdf')
        plt.show()

        os.environ['WANDB_DIR'] = wandb_dir
        wandb.init(project=wandb_project, entity=wandb_entity, name=run_name)
        try:
            artifact = wandb.Artifact('plot_pdf', type='report')
            artifact.add_file(pdf_file)
            wandb.log_artifact(artifact)
        finally:
            # Close the run explicitly so it is not left open after the upload
            # (or after a failed upload).
            wandb.finish()
        print("Plot successfully saved and uploaded to Weights and Biases.")

    except Exception as e:
        import traceback
        # Best-effort boundary: report, including where it failed, but don't raise.
        print(f"An error occurred: {e}")
        traceback.print_exc()
    finally:
        # pyplot keeps figures alive until closed. Release this one so repeated
        # calls (e.g. on a non-interactive backend) do not accumulate figures.
        if fig is not None:
            plt.close(fig)

# Script entry: plot confidence intervals for the job numbers below and upload the
# resulting PDF to Weights and Biases. Edit job_ids to select which jobs'
# aggregated results are plotted.
job_ids = ['328339', '328336', '328332', '328325', '328327', '328321']
colors = ['black', 'blue', 'cyan', 'green', 'magenta', 'red']
# The '\n' in the fifth label is an intentional line break inside the legend entry.
labels = ['FedAvg', 'FedAdaGrad', 'FedAdam', 'Direct Joint Adap.', 'Joint Adap. w/o\nPrecond. Commu.', 'FedAda$^2$']
metric = 'pseudogradient l2 norm standard deviation'
run_name = 'CI_DP_test0'
markers = ['x', 'v', '<', 's', 'd', 'D']

# ylabel is a raw string: '\e' in mathtext would otherwise be an invalid escape
# sequence (SyntaxWarning on Python 3.12+). The resulting text is unchanged.
plot_results(job_ids, colors, labels, metric=metric, title='Pseudo-Gradient',
             xlabel='Communication Rounds', ylabel=r'$\ell_2$-Norm SD', log_every=5,
             start_idx=5, end_idx=40, figsize=(7, 6), title_fontsize=45, label_fontsize=33,
             legend_fontsize=18, tick_fontsize=18, wandb_entity='some-agents',
             run_name=run_name, line_thickness=2, markers=markers, max_markers=8,
             x_logarithm=False, y_logarithm=False, manual_ticks=False)

